/*
 * @Author: leox_tian
 * @Date: 2021-10-20 13:55:57
 * @Description: header-only DBSCAN clustering (DBScan class template)
 */
/**
 * @file dbscan.h
 * @author tianchangxin
 * @brief Implementation of DBSCAN algorithm from https://zh.wikipedia.org/wiki/DBSCAN
 */
#pragma once
#include <algorithm>
#include <cmath>
#include <deque>
#include <iostream>
#include <utility>
#include <vector>

#include <Eigen/Core>
#include <Eigen/Dense>

using std::deque;
using std::vector;

// Plain 2-D point with single-precision coordinates; a default-constructed point is the origin.
struct Point2d
{
    float x{0.0f}; // horizontal coordinate
    float y{0.0f}; // vertical coordinate
};

// namespace TBD
// {

/**
 * @brief Density-based spatial clustering of applications with noise.
 * Implementation of the DBSCAN algorithm from https://zh.wikipedia.org/wiki/DBSCAN.
 * A precomputed pairwise distance matrix avoids recomputing distances during region
 * queries, at the cost of O(n^2) memory.
 * @tparam T point/object type; a matching GetDistance overload must exist for it.
 */
template <class T> class DBScan
{
  private:
    // minimum number of points in a point's eps-neighborhood (the point itself included)
    // for that point to start/extend a cluster
    int _min_num;

    // maximum distance threshold of two points
    double _eps;

    // cluster label of each point: -1 noise, 0 unlabeled, >=1 cluster index
    vector<int> _labels;

    // pairwise distances between all points (symmetric, zero diagonal)
    Eigen::MatrixXd _distance_mat;

    // TODO: you should define the distance function of your struct/class
    /**
     * @brief Euclidean distance between two Eigen 2-D vectors.
     * @param p1 object/point 1
     * @param p2 object/point 2
     * @return double distance
     */
    double GetDistance(const Eigen::Vector2d& p1, const Eigen::Vector2d& p2)
    {
        return (p1 - p2).norm();
    }

    /**
     * @brief Euclidean distance between two Point2d.
     * Differences are taken in float (as the coordinates are stored) and squared in double.
     */
    double GetDistance(const Point2d& p1, const Point2d& p2)
    {
        const double dx = p1.x - p2.x;
        const double dy = p1.y - p2.y;
        // plain multiplication instead of std::pow(x, 2): same result, no transcendental call
        return std::sqrt(dx * dx + dy * dy);
    }

    /**
     * @brief core function of dbscan; recomputes distances and labels from scratch
     * @param all_pts, all points/objects
     * @return int, number of clusters
     */
    int Run(const vector<T> &all_pts);

    /**
     * @brief find neighbors within a fixed distance, and generate a cluster
     * @param pt_idx, point index of all points
     * @param cluster_idx, cluster index
     * @return bool, true:cluster is generated, false:point/object is noise
     */
    bool ExpandCluster(int pt_idx, int &cluster_idx);

  public:
    /**
     * @brief Construct a new DBScan object
     * @param eps, define maximum distance threshold of two points
     * @param min_num, define minimum number of points in a cluster
     */
    // initializers listed in member declaration order (_min_num before _eps)
    DBScan(float eps, int min_num) : _min_num(min_num), _eps(eps) {}

    // No user-declared destructor (Rule of Zero): declaring one would suppress the
    // implicit move operations and force copies of the distance matrix and labels.

    /**
     * @brief Get the Clusters object; runs the clustering and refreshes GetLabels()
     * @param all_pts, all points/objects
     * @return vector<vector<T>>, 0:noise points, 1-n:cluster points
     */
    vector<vector<T>> GetClusters(std::vector<T> all_pts);

    /**
     * @brief Get the Labels of object
     * @return vector<int>, get the cluster label corresponds to the index of each object
     *         (-1 noise, >=1 cluster index) from the last clustering run
     */
    vector<int> GetLabels() const { return _labels; }
};

template <typename T> int DBScan<T>::Run(const vector<T> &all_pts)
{
    const int size = static_cast<int>(all_pts.size());
    int cluster_idx = 1;

    // -1:noise, 0:unlabeled, >=1:cluster index.
    // assign() rather than resize(): resize() keeps labels left over from a previous
    // call, which would make every point look already visited on a re-run.
    _labels.assign(all_pts.size(), 0);

    // 1. calculate the distance between each two points (candidate for OpenMP if data is big).
    // The matrix is symmetric with a zero diagonal, so only the upper triangle is computed.
    _distance_mat = Eigen::MatrixXd::Zero(size, size);
    for (int i = 0; i < size; ++i)
    {
        for (int j = i + 1; j < size; ++j)
        {
            const double dist = GetDistance(all_pts[i], all_pts[j]);
            _distance_mat(i, j) = dist;
            _distance_mat(j, i) = dist;
        }
    }

    // 2. do clustering: every still-unlabeled point either seeds a new cluster or becomes noise
    for (int i = 0; i < size; ++i)
    {
        if (_labels[i] == 0)
            ExpandCluster(i, cluster_idx);
    }

    // cluster_idx is one past the last cluster index that was handed out
    return cluster_idx - 1;
}

template <typename T> bool DBScan<T>::ExpandCluster(int pt_idx, int &cluster_idx)
{
    const int num_pts = static_cast<int>(_distance_mat.cols());

    // Region query shared by the seed point and every expanded point, so both use the same
    // neighborhood definition: all points within _eps (inclusive), including `center` itself
    // (its distance to itself is 0). Fills `neighbors` and returns true if `center` is a core
    // point (at least _min_num points in its neighborhood).
    vector<int> neighbors;
    neighbors.reserve(num_pts);
    auto region_query = [&](int center) {
        neighbors.clear();
        for (int col = 0; col < num_pts; ++col)
        {
            if (_distance_mat(center, col) <= _eps)
                neighbors.emplace_back(col);
        }
        return static_cast<int>(neighbors.size()) >= _min_num;
    };

    // Claim the current `neighbors` for this cluster:
    //  - unlabeled points join the cluster and are queued for expansion;
    //  - noise points become border points (already found not to be core, so not queued);
    //  - points already in a cluster keep their label.
    deque<int> seeds_idx;
    auto absorb_neighbors = [&]() {
        for (const int idx : neighbors)
        {
            if (_labels[idx] == 0)
            {
                _labels[idx] = cluster_idx;
                seeds_idx.emplace_back(idx);
            }
            else if (_labels[idx] == -1)
            {
                _labels[idx] = cluster_idx;
            }
        }
    };

    // 1-2. region query around the seed; not a core point -> noise
    //      (it may still be claimed later as a border point of another cluster)
    if (!region_query(pt_idx))
    {
        _labels[pt_idx] = -1;
        return false;
    }

    // 3. the seed is a core point: start a new cluster and claim its neighborhood
    _labels[pt_idx] = cluster_idx;
    absorb_neighbors();

    // 4. breadth-first expansion: only core points extend the cluster further
    while (!seeds_idx.empty())
    {
        const int row = seeds_idx.front();
        seeds_idx.pop_front();
        if (region_query(row))
            absorb_neighbors();
    }

    cluster_idx++;
    return true;
}
template <typename T> vector<vector<T>> DBScan<T>::GetClusters(std::vector<T> all_pts)
{
    const int clusters_num = this->Run(all_pts);

    // slot 0: noise points; slot k (1..clusters_num): points of cluster k
    vector<vector<T>> result(clusters_num + 1);

    for (size_t i = 0; i < _labels.size(); ++i)
    {
        const int label = _labels[i];
        // Every non-positive label goes to the noise slot explicitly; mapping via
        // `label + 1` would send an unlabeled point (0) into cluster 1.
        const size_t slot = label > 0 ? static_cast<size_t>(label) : 0;
        // all_pts is this function's own copy and Run() is done with it, so move instead of copy.
        result[slot].emplace_back(std::move(all_pts[i]));
    }

    return result;
}
// } // namespace TBD